import os
from dataclasses import dataclass, field
from typing import List
import random
import copy
import cv2
import torch
import albumentations as A
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import numpy as np
from torch import nn, optim
# Device selection: train on GPU when available.  CPU_DEVICE is kept as a
# separate constant because the training loop temporarily moves the model to
# CPU before deep-copying the best checkpoint.
CPU_DEVICE = 'cpu'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# DEVICE = 'cpu'
DEVICE
'cuda'
# Widen the notebook cells to the full browser width.
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
IS_BIG_STRIDE = True # if using images with big stride (768 instead of 256 pixels). Faster training and validation/test on random samples from all files
# TILES_BASE_DIR = "/media/data/local/corn/processed"
TILES_BASE_DIR = "/media/data/local/corn/processed_stride768"
# Field subdirectories (one per corn field) providing the training tiles.
SUBDIRECTORIES_TO_PROCESS_TRAIN = [
"kukurydza_5_ha",
"kukurydza_10_ha",
"kukurydza_11_ha",
"kukurydza_13_ha",
"kukurydza_15_ha",
"kukurydza_18_ha",
"kukurydza_25_ha",
"kukurydza_38_ha",
"kukurydza_60_ha",
]
# If the TEST or VALID lists below are empty, a random 80/10/10 split of the
# training set is used instead (see the split logic further down).
SUBDIRECTORIES_TO_PROCESS_VALID = [
# "kukurydza_10_ha",
]
SUBDIRECTORIES_TO_PROCESS_TEST = [
]
# Tiles on disk are 768x768; augmentation may rotate/scale them, after which a
# centered 512x512 crop is taken so rotated corners never show empty borders.
UNCROPPED_TILE_SIZE = (512 + 256) # in pixels
CROPPED_TILE_SIZE = 512
CROP_TILE_MARGIN = (UNCROPPED_TILE_SIZE - CROPPED_TILE_SIZE) // 2
@dataclass
class TilesPaths:
    """Index-aligned lists of tile file paths: img_paths[i] pairs with mask_paths[i]."""
    # `default_factory=list` is the idiomatic form of `lambda: []`.
    img_paths: List[str] = field(default_factory=list)
    mask_paths: List[str] = field(default_factory=list)
def get_tile_paths_for_directories(directories, shuffle=True) -> TilesPaths:
    """Collect paired image/mask tile paths from subdirectories of TILES_BASE_DIR.

    Tiles are expected on disk as '<prefix>_img.png' / '<prefix>_mask.png' pairs.

    Args:
        directories: subdirectory names (relative to TILES_BASE_DIR) to scan.
        shuffle: randomize pair order within each directory (uses the global
            `random` state, so seed it externally for reproducibility).

    Returns:
        TilesPaths with index-aligned img_paths and mask_paths.

    Raises:
        ValueError: if any prefix has an image without a mask, or vice versa.
    """
    tile_paths = TilesPaths()
    for dir_name in directories:
        dir_path = os.path.join(TILES_BASE_DIR, dir_name)
        file_names = [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
        # Prefix = file name up to the last underscore (strips the '_img'/'_mask' suffix).
        mask_files_prefixes = {f[:f.rfind('_')] for f in file_names if 'mask' in f}
        img_files_prefixes = {f[:f.rfind('_')] for f in file_names if 'img' in f}
        common_files_prefixes = mask_files_prefixes & img_files_prefixes
        missing_files_prefixes = (mask_files_prefixes | img_files_prefixes) - common_files_prefixes
        if missing_files_prefixes:
            raise ValueError(f"Some files don't have corresponding pair in mask/image: {missing_files_prefixes}")
        common_files_prefixes = list(common_files_prefixes)
        if shuffle:
            random.shuffle(common_files_prefixes)
        for file_prefix in common_files_prefixes:
            tile_paths.img_paths.append(os.path.join(dir_path, file_prefix + '_img.png'))
            tile_paths.mask_paths.append(os.path.join(dir_path, file_prefix + '_mask.png'))
    return tile_paths
# Build the train/valid/test path lists.  Dedicated valid+test directories are
# used only when BOTH lists are non-empty; otherwise everything comes from the
# training directories and is split randomly 80/10/10.
if SUBDIRECTORIES_TO_PROCESS_VALID and SUBDIRECTORIES_TO_PROCESS_TEST:
# we have valid tiles for test/valid
tile_paths_train = get_tile_paths_for_directories(SUBDIRECTORIES_TO_PROCESS_TRAIN)
tile_paths_valid = get_tile_paths_for_directories(SUBDIRECTORIES_TO_PROCESS_VALID)
tile_paths_test = get_tile_paths_for_directories(SUBDIRECTORIES_TO_PROCESS_TEST)
else:
# shuffle=True so the 80/10/10 split below is a random sample, not file order
tile_paths_all = get_tile_paths_for_directories(SUBDIRECTORIES_TO_PROCESS_TRAIN, shuffle=True)
N = len(tile_paths_all.img_paths)
sp = [int(N*0.8), int(N*0.9)] # dataset split points
tile_paths_train = TilesPaths(img_paths=tile_paths_all.img_paths[:sp[0]], mask_paths=tile_paths_all.mask_paths[:sp[0]])
tile_paths_valid = TilesPaths(img_paths=tile_paths_all.img_paths[sp[0]:sp[1]], mask_paths=tile_paths_all.mask_paths[sp[0]:sp[1]])
tile_paths_test = TilesPaths(img_paths=tile_paths_all.img_paths[sp[1]:], mask_paths=tile_paths_all.mask_paths[sp[1]:])
print(f'Number of tiles train = {len(tile_paths_train.img_paths)}')
print(f'Number of tiles validation = {len(tile_paths_valid.img_paths)}')
print(f'Number of tiles test = {len(tile_paths_test.img_paths)}')
N
Number of tiles train = 2412 Number of tiles validation = 301 Number of tiles test = 302
3015
# Grayscale pixel values in the mask PNGs, one per segmentation class;
# channel order of the one-hot masks follows this list.
SEGMENTATION_CLASS_VALUES = [0, 255, 127]
NUMBER_OF_SEGMENTATION_CLASSES = len(SEGMENTATION_CLASS_VALUES)
class CornFieldDamageDataset(torch.utils.data.Dataset):
    """Dataset of (image, one-hot mask) tile pairs loaded from disk.

    Each item is a pair of float32 numpy arrays:
      * image: (3, H, W), RGB, scaled to [0, 1]
      * mask:  (NUMBER_OF_SEGMENTATION_CLASSES, H, W), one channel per class
        value in SEGMENTATION_CLASS_VALUES (0/1 indicator planes)
    """
    def __init__(self, img_file_paths, mask_file_paths, augment=True):
        """
        Args:
            img_file_paths: list of image tile paths.
            mask_file_paths: list of mask tile paths, index-aligned with images.
            augment: apply random augmentation (train) vs. center-crop only (eval).

        Raises:
            ValueError: if the two path lists differ in length.
        """
        self.img_file_paths = img_file_paths
        self.mask_file_paths = mask_file_paths
        # Explicit exception instead of `assert`: asserts are stripped under -O.
        if len(self.img_file_paths) != len(self.mask_file_paths):
            raise ValueError('img_file_paths and mask_file_paths must have equal length')
        if augment:
            self._img_and_mask_transform = self._get_img_and_mask_augmentation_transform()  # augmentation transform
        else:
            self._img_and_mask_transform = self._get_img_and_mask_crop_transform()  # crop only transform
    def __len__(self):
        return len(self.mask_file_paths)
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = cv2.imread(self.img_file_paths[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # not really needed I guess
        mask = cv2.imread(self.mask_file_paths[idx], cv2.IMREAD_GRAYSCALE)
        # The same spatial transform is applied to image and mask together.
        transformed = self._img_and_mask_transform(image=image, mask=mask)
        image, mask = transformed['image'], transformed['mask']
        # One indicator plane per class value; single float32 cast (the original
        # converted to float64 first, then down to float32).
        masks = [(mask == v) for v in SEGMENTATION_CLASS_VALUES]
        mask_stacked = np.stack(masks, axis=0).astype('float32')
        image = image.astype('float32')
        image /= 255
        image = image.transpose(2, 0, 1)  # HWC -> CHW for PyTorch
        return image, mask_stacked
    def _get_img_and_mask_augmentation_transform(self):
        """Return the training augmentation pipeline (flips, scale, rotation, center crop)."""
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomScale(scale_limit=0.15), # above scale 0.16 images are too small
            A.Rotate(limit=90), # degrees
            # Center crop removes the border region so rotation/scaling never
            # exposes empty corners.
            A.Crop(x_min=CROP_TILE_MARGIN, y_min=CROP_TILE_MARGIN, x_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN, y_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN),
        ])
        # TODO - color, contrast, gamma, randomShadow, rain
        return transform
    def _get_img_and_mask_crop_transform(self):
        """Return the evaluation pipeline: deterministic center crop only."""
        transform = A.Compose([
            A.Crop(x_min=CROP_TILE_MARGIN, y_min=CROP_TILE_MARGIN, x_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN, y_max=UNCROPPED_TILE_SIZE-CROP_TILE_MARGIN),
        ])
        return transform
# Datasets: augmentation only for training; valid/test use the plain center crop.
train_dataset = CornFieldDamageDataset(img_file_paths=tile_paths_train.img_paths, mask_file_paths=tile_paths_train.mask_paths)
valid_dataset = CornFieldDamageDataset(img_file_paths=tile_paths_valid.img_paths, mask_file_paths=tile_paths_valid.mask_paths, augment=False)
test_dataset = CornFieldDamageDataset(img_file_paths=tile_paths_test.img_paths, mask_file_paths=tile_paths_test.mask_paths, augment=False)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=6, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=6, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=6, shuffle=True)
len(train_loader)
402
# image, mask = train_dataset[222] # get some sample
# plt.imshow(mask[0, :, :])
# plt.show()
# plt.imshow(mask[1, :, :])
# plt.show()
# plt.imshow(mask[2, :, :])
# plt.show()
# plt.imshow(image.transpose(1, 2, 0))
# U-Net with an ImageNet-pretrained ResNet-34 encoder (segmentation_models_pytorch).
model = smp.Unet(
encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes=NUMBER_OF_SEGMENTATION_CLASSES, # model output channels (number of classes in your dataset)
activation='softmax2d', # softmax over the channel dimension (Softmax(dim=1) in the head)
)
print(model)
Unet(
(encoder): ResNetEncoder(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(decoder): UnetDecoder(
(center): Identity()
(blocks): ModuleList(
(0): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(1): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(2): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(3): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
(4): DecoderBlock(
(conv1): Conv2dReLU(
(0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention1): Attention(
(attention): Identity()
)
(conv2): Conv2dReLU(
(0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(attention2): Attention(
(attention): Identity()
)
)
)
)
(segmentation_head): SegmentationHead(
(0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Identity()
(2): Activation(
(activation): Softmax(dim=1)
)
)
)
# criterion = nn.CrossEntropyLoss()
loss = smp.utils.losses.DiceLoss()
# Per-class IoU is obtained by ignoring the other two channels; the remaining
# metrics exclude channel 2 (presumably the 'other/out-of-field' class — verify
# against SEGMENTATION_CLASS_VALUES ordering).
metrics = [
smp.utils.metrics.IoU(threshold=0.5, name='IoU'),
smp.utils.metrics.IoU(threshold=0.5, ignore_channels=[1, 2], name='IoU-0'),
smp.utils.metrics.IoU(threshold=0.5, ignore_channels=[0, 2], name='IoU-1'),
smp.utils.metrics.IoU(threshold=0.5, ignore_channels=[0, 1], name='IoU-2'),
smp.utils.metrics.Fscore(threshold=0.5, ignore_channels=[2]),
smp.utils.metrics.Accuracy(threshold=0.5, ignore_channels=[2]),
smp.utils.metrics.Recall(threshold=0.5, ignore_channels=[2]),
smp.utils.metrics.Precision(threshold=0.5, ignore_channels=[2]),
]
# optimizer = optim.SGD(model_fnn.parameters(), lr=0.0001, momentum=0.9)
optimizer = torch.optim.Adam([
dict(params=model.parameters(), lr=0.000012), # 0.0001 # 000003 gives 80 epoch for 768 stride
])
# smp's epoch runners handle the batch loop, loss/metric aggregation and tqdm output.
train_epoch = smp.utils.train.TrainEpoch(
model,
loss=loss,
metrics=metrics,
optimizer=optimizer,
device=DEVICE,
verbose=True,
)
valid_epoch = smp.utils.train.ValidEpoch(
model,
loss=loss,
metrics=metrics,
device=DEVICE,
verbose=True,
)
# Rename the per-class IoU metrics so their log keys are distinguishable.
# Both epoch runners share the same metric objects, so this sets each name twice.
for e in [valid_epoch, train_epoch]:
e.metrics[1].__name__="IoU_Class0"
e.metrics[2].__name__="IoU_Class1"
e.metrics[3].__name__="IoU_Class2"
# Main training loop: track per-epoch logs, checkpoint the model with the best
# validation IoU, and drop the learning rate once at epoch 15.
max_score = 0
train_logs_vec = []
valid_logs_vec = []
best_model = None
for i in range(0, 40):
print(f'\nEpoch: {i}')
train_logs_vec.append(train_epoch.run(train_loader))
valid_logs = valid_epoch.run(valid_loader)
valid_logs_vec.append(valid_logs)
if max_score < valid_logs['iou_score']:
max_score = valid_logs['iou_score']
# Move to CPU before deepcopy so the checkpoint doesn't hold GPU memory,
# then move the live model back for the next epoch.
model.to(CPU_DEVICE)
best_model = copy.deepcopy(model)
model.to(DEVICE)
if i == 15:
optimizer.param_groups[0]['lr'] = 1e-5
print('Decrease decoder learning rate to 1e-5!')
Epoch: 0 train: 100%|█| 402/402 [02:34<00:00, 2.60it/s, dice_loss - 0.5873, iou_score - 0.2957, IoU_Class0 - 0.3004, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.46it/s, dice_loss - 0.4403, iou_score - 0.6907, IoU_Class0 - 0.715, IoU_Class1 Epoch: 1 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.4183, iou_score - 0.6087, IoU_Class0 - 0.6489, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.46it/s, dice_loss - 0.3075, iou_score - 0.8201, IoU_Class0 - 0.8552, IoU_Class Epoch: 2 train: 100%|█| 402/402 [02:33<00:00, 2.63it/s, dice_loss - 0.3287, iou_score - 0.7256, IoU_Class0 - 0.7898, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.48it/s, dice_loss - 0.2367, iou_score - 0.8926, IoU_Class0 - 0.9357, IoU_Class Epoch: 3 train: 100%|█| 402/402 [02:35<00:00, 2.59it/s, dice_loss - 0.2608, iou_score - 0.8108, IoU_Class0 - 0.8797, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.33it/s, dice_loss - 0.1983, iou_score - 0.89, IoU_Class0 - 0.9346, IoU_Class1 Epoch: 4 train: 100%|█| 402/402 [02:35<00:00, 2.59it/s, dice_loss - 0.2121, iou_score - 0.8295, IoU_Class0 - 0.8933, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.37it/s, dice_loss - 0.1401, iou_score - 0.915, IoU_Class0 - 0.9512, IoU_Class1 Epoch: 5 train: 100%|█| 402/402 [02:33<00:00, 2.63it/s, dice_loss - 0.1745, iou_score - 0.8443, IoU_Class0 - 0.903, IoU_Clas valid: 100%|█| 51/51 [00:09<00:00, 5.48it/s, dice_loss - 0.1399, iou_score - 0.891, IoU_Class0 - 0.9359, IoU_Class1 Epoch: 6 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.1466, iou_score - 0.8558, IoU_Class0 - 0.9104, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.46it/s, dice_loss - 0.1045, iou_score - 0.9137, IoU_Class0 - 0.9499, IoU_Class Epoch: 7 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.1288, iou_score - 0.8594, IoU_Class0 - 0.9127, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.50it/s, dice_loss - 0.089, iou_score - 0.9178, IoU_Class0 - 0.9532, IoU_Class1 Epoch: 8 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 
0.1156, iou_score - 0.8646, IoU_Class0 - 0.9157, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.49it/s, dice_loss - 0.08147, iou_score - 0.9136, IoU_Class0 - 0.9506, IoU_Clas Epoch: 9 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.1044, iou_score - 0.8692, IoU_Class0 - 0.9188, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.49it/s, dice_loss - 0.0742, iou_score - 0.9118, IoU_Class0 - 0.9489, IoU_Class Epoch: 10 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.09782, iou_score - 0.8697, IoU_Class0 - 0.919, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.46it/s, dice_loss - 0.07094, iou_score - 0.9118, IoU_Class0 - 0.9484, IoU_Clas Epoch: 11 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.09061, iou_score - 0.8714, IoU_Class0 - 0.9199, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.49it/s, dice_loss - 0.06445, iou_score - 0.9109, IoU_Class0 - 0.9485, IoU_Clas Epoch: 12 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.08576, iou_score - 0.8732, IoU_Class0 - 0.9204, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.49it/s, dice_loss - 0.05965, iou_score - 0.9163, IoU_Class0 - 0.9527, IoU_Clas Epoch: 13 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.08148, iou_score - 0.8743, IoU_Class0 - 0.9216, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.49it/s, dice_loss - 0.06075, iou_score - 0.9094, IoU_Class0 - 0.9483, IoU_Clas Epoch: 14 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.0764, iou_score - 0.8791, IoU_Class0 - 0.9247, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.48it/s, dice_loss - 0.05544, iou_score - 0.9147, IoU_Class0 - 0.9517, IoU_Clas Epoch: 15 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.07596, iou_score - 0.877, IoU_Class0 - 0.9226, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.44it/s, dice_loss - 0.05142, iou_score - 0.9165, IoU_Class0 - 0.9527, IoU_Clas Decrease decoder learning rate to 1e-5! 
Epoch: 16 train: 100%|█| 402/402 [02:33<00:00, 2.61it/s, dice_loss - 0.07165, iou_score - 0.8802, IoU_Class0 - 0.9257, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.50it/s, dice_loss - 0.05021, iou_score - 0.9171, IoU_Class0 - 0.9508, IoU_Clas Epoch: 17 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.0685, iou_score - 0.8838, IoU_Class0 - 0.9275, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.46it/s, dice_loss - 0.04897, iou_score - 0.9163, IoU_Class0 - 0.9512, IoU_Clas Epoch: 18 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.06818, iou_score - 0.8831, IoU_Class0 - 0.9268, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.45it/s, dice_loss - 0.05354, iou_score - 0.9092, IoU_Class0 - 0.9459, IoU_Clas Epoch: 19 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.06649, iou_score - 0.8851, IoU_Class0 - 0.928, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.48it/s, dice_loss - 0.04864, iou_score - 0.9159, IoU_Class0 - 0.9513, IoU_Clas Epoch: 20 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.06522, iou_score - 0.8856, IoU_Class0 - 0.9288, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.48it/s, dice_loss - 0.05369, iou_score - 0.9055, IoU_Class0 - 0.9451, IoU_Clas Epoch: 21 train: 100%|█| 402/402 [02:33<00:00, 2.62it/s, dice_loss - 0.06462, iou_score - 0.886, IoU_Class0 - 0.9288, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.52it/s, dice_loss - 0.04446, iou_score - 0.9208, IoU_Class0 - 0.9548, IoU_Clas Epoch: 22 train: 100%|█| 402/402 [02:32<00:00, 2.63it/s, dice_loss - 0.06411, iou_score - 0.886, IoU_Class0 - 0.9289, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.51it/s, dice_loss - 0.05131, iou_score - 0.9081, IoU_Class0 - 0.9473, IoU_Clas Epoch: 23 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.06293, iou_score - 0.8873, IoU_Class0 - 0.9301, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.51it/s, dice_loss - 0.04786, iou_score - 0.9133, IoU_Class0 - 0.9506, IoU_Clas Epoch: 24 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, 
dice_loss - 0.06275, iou_score - 0.8874, IoU_Class0 - 0.9295, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.50it/s, dice_loss - 0.04956, iou_score - 0.9103, IoU_Class0 - 0.9438, IoU_Clas Epoch: 25 train: 100%|█| 402/402 [02:32<00:00, 2.63it/s, dice_loss - 0.06114, iou_score - 0.8897, IoU_Class0 - 0.9314, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.51it/s, dice_loss - 0.04166, iou_score - 0.9235, IoU_Class0 - 0.9563, IoU_Clas Epoch: 26 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.06018, iou_score - 0.8909, IoU_Class0 - 0.9321, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.51it/s, dice_loss - 0.04696, iou_score - 0.914, IoU_Class0 - 0.9504, IoU_Class Epoch: 27 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.06039, iou_score - 0.8903, IoU_Class0 - 0.9315, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.54it/s, dice_loss - 0.0526, iou_score - 0.9043, IoU_Class0 - 0.9439, IoU_Class Epoch: 28 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.06029, iou_score - 0.8901, IoU_Class0 - 0.9315, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.54it/s, dice_loss - 0.04919, iou_score - 0.9108, IoU_Class0 - 0.9486, IoU_Clas Epoch: 29 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.0592, iou_score - 0.8914, IoU_Class0 - 0.9327, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.52it/s, dice_loss - 0.05266, iou_score - 0.9025, IoU_Class0 - 0.9443, IoU_Clas Epoch: 30 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.05799, iou_score - 0.8936, IoU_Class0 - 0.9339, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.54it/s, dice_loss - 0.04735, iou_score - 0.9126, IoU_Class0 - 0.9495, IoU_Clas Epoch: 31 train: 100%|█| 402/402 [02:32<00:00, 2.63it/s, dice_loss - 0.05976, iou_score - 0.8905, IoU_Class0 - 0.9311, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.53it/s, dice_loss - 0.05124, iou_score - 0.9062, IoU_Class0 - 0.9429, IoU_Clas Epoch: 32 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.05839, iou_score - 0.8926, IoU_Class0 - 0.9324, 
IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.51it/s, dice_loss - 0.04484, iou_score - 0.9161, IoU_Class0 - 0.9516, IoU_Clas Epoch: 33 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.05904, iou_score - 0.8914, IoU_Class0 - 0.9319, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.53it/s, dice_loss - 0.05306, iou_score - 0.9028, IoU_Class0 - 0.9415, IoU_Clas Epoch: 34 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.05785, iou_score - 0.8931, IoU_Class0 - 0.9334, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.50it/s, dice_loss - 0.05077, iou_score - 0.906, IoU_Class0 - 0.9448, IoU_Class Epoch: 35 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.05804, iou_score - 0.8929, IoU_Class0 - 0.9331, IoU_Cl valid: 100%|█| 51/51 [00:09<00:00, 5.53it/s, dice_loss - 0.04548, iou_score - 0.9151, IoU_Class0 - 0.9511, IoU_Clas Epoch: 36 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.0571, iou_score - 0.8945, IoU_Class0 - 0.934, IoU_Clas valid: 100%|█| 51/51 [00:09<00:00, 5.55it/s, dice_loss - 0.04835, iou_score - 0.9095, IoU_Class0 - 0.9489, IoU_Clas Epoch: 37 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.0565, iou_score - 0.8954, IoU_Class0 - 0.9345, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.52it/s, dice_loss - 0.03977, iou_score - 0.9248, IoU_Class0 - 0.9573, IoU_Clas Epoch: 38 train: 100%|█| 402/402 [02:32<00:00, 2.63it/s, dice_loss - 0.05705, iou_score - 0.8946, IoU_Class0 - 0.934, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.51it/s, dice_loss - 0.04505, iou_score - 0.9157, IoU_Class0 - 0.952, IoU_Class Epoch: 39 train: 100%|█| 402/402 [02:32<00:00, 2.64it/s, dice_loss - 0.05678, iou_score - 0.8948, IoU_Class0 - 0.934, IoU_Cla valid: 100%|█| 51/51 [00:09<00:00, 5.51it/s, dice_loss - 0.04338, iou_score - 0.9185, IoU_Class0 - 0.9535, IoU_Clas
# One figure per logged metric: train vs. validation curve over epochs.
for metric_name in valid_logs_vec[0]:
    plt.plot([epoch_log[metric_name] for epoch_log in train_logs_vec])
    plt.plot([epoch_log[metric_name] for epoch_log in valid_logs_vec])
    plt.legend(['train', 'valid'])
    plt.xlabel('epoch')
    plt.ylabel(metric_name)
    plt.grid()
    plt.show()
    # ax = plt.gca()
    # ax.set_yscale('log')
# Evaluate the best validation checkpoint on the held-out test set.
model = best_model
model.to(DEVICE)
test_epoch = smp.utils.train.ValidEpoch(
model,
loss=loss,
metrics=metrics,
device=DEVICE,
verbose=True,
)
test_epoch.run(test_loader)
valid: 100%|█| 51/51 [00:09<00:00, 5.47it/s, dice_loss - 0.04163, iou_score - 0.922, IoU_Class0 - 0.9541, IoU_Class
{'dice_loss': 0.04163072974074121,
'iou_score': 0.9219955813650991,
'IoU_Class0': 0.9541221237650105,
'IoU_Class1': 0.5101585100863784,
'IoU_Class2': 0.948488014912629,
'fscore': 0.9583766986342036,
'accuracy': 0.9585827472163183,
'recall': 0.9583657500790611,
'precision': 0.9583876635514053}
# Qualitative inspection: for a few test batches, plot image, ground-truth
# damage mask, predicted damage, prediction-minus-truth diff, and the predicted
# healthy-field channel side by side.
model = model.to('cpu')
# torch.save(model.state_dict(), '/media/data/local/corn/results/train_all_files_768_stride/2/model_cpu')
vi = iter(test_loader)
# vi = iter(train_loader)
for i in range(8): # increase to get more images
img_batch, mask_batch = next(vi)
with torch.no_grad():
# model_output = model(img_batch.to(DEVICE))
model_output = model(img_batch)
columns = 5
rows = len(img_batch)
fig = plt.figure(figsize=(columns * 4, rows * 4))
# NOTE(review): this inner loop shadows the outer `i`; harmless here because
# the outer `range` iterator is unaffected, but worth renaming.
for i in range(len(img_batch)):
fig.add_subplot(rows, columns, 1 + i*columns + 0)
plt.imshow(img_batch[i].numpy().transpose([1, 2, 0]))
plt.axis('off')
plt.title('img')
fig.add_subplot(rows, columns, 1 + i*columns + 1)
plt.imshow(mask_batch[i][1].numpy())
plt.axis('off')
plt.title('original damage mask')
fig.add_subplot(rows, columns, 1 + i*columns + 2)
plt.imshow(model_output[i][1])
plt.axis('off')
plt.title('prediction damage')
fig.add_subplot(rows, columns, 1 + i*columns + 3)
cax = plt.imshow(model_output[i][1] - mask_batch[i][1], vmin=-1.1, vmax=1.1)
plt.title('damage diff (predict-gt)')
plt.axis('off')
cbar = fig.colorbar(cax, ticks=[-1, 0, 1])
cbar.ax.set_yticklabels(['false negative', 'true', 'false positive'])
fig.add_subplot(rows, columns, 1 + i*columns + 4)
plt.imshow(model_output[i][0])
plt.title('prediction healty field')
plt.axis('off')
plt.show()
# Pixel-level statistics over the test set: class balance, damage precision,
# and per-class IoU computed from thresholded (>0.5) predictions.
device = 'cpu'
# device = DEVICE
model = model.to(device)
# Bug fix: the progress total must come from the loader actually iterated
# below (test_loader), not valid_loader.
number_of_batches = len(test_loader)
healthy_field_ground_truth_pix = 0
damage_ground_truth_pix = 0
healthy_field_predicted_pix = 0
damage_field_predicted_pix = 0
damage_prediction_true_positives_pix = 0
healthy_intersection_pix = 0
healthy_union_pix = 0
damage_intersection_pix = 0
damage_union_pix = 0
for batch_idx, (img_batch, mask_batch) in enumerate(test_loader):
    print(f'Batch {batch_idx} / {number_of_batches}')
    with torch.no_grad():
        model_output = model(img_batch.to(device)).to(CPU_DEVICE)
    # Separate index name: the original reused `i`, shadowing the batch index.
    for sample_idx in range(model_output.shape[0]):
        # Channel 0 = healthy field, channel 1 = damage (per dataset one-hot order).
        ground_truth_healthy_field = mask_batch[sample_idx, 0, :, :].numpy().astype(int)
        ground_truth_damage = mask_batch[sample_idx, 1, :, :].numpy().astype(int)
        predicted_healthy_field = model_output[sample_idx, 0, :, :].numpy()
        predicted_damage = model_output[sample_idx, 1, :, :].numpy()
        predicted_healthy_field = np.where(predicted_healthy_field > 0.5, 1, 0)
        predicted_damage = np.where(predicted_damage > 0.5, 1, 0)
        healthy_field_ground_truth_pix += np.count_nonzero(ground_truth_healthy_field)
        damage_ground_truth_pix += np.count_nonzero(ground_truth_damage)
        healthy_field_predicted_pix += np.count_nonzero(predicted_healthy_field)
        damage_field_predicted_pix += np.count_nonzero(predicted_damage)
        common_damage = np.logical_and(ground_truth_damage, predicted_damage)
        damage_prediction_true_positives_pix += np.count_nonzero(common_damage)
        common_healthy = np.logical_and(ground_truth_healthy_field, predicted_healthy_field)
        damage_intersection_pix += np.count_nonzero(common_damage)
        healthy_intersection_pix += np.count_nonzero(common_healthy)
        damage_union_pix += np.count_nonzero(np.logical_or(ground_truth_damage, predicted_damage))
        healthy_union_pix += np.count_nonzero(np.logical_or(ground_truth_healthy_field, predicted_healthy_field))
total_ground_truth_pix = healthy_field_ground_truth_pix + damage_ground_truth_pix
total_predicted_pix = healthy_field_predicted_pix + damage_field_predicted_pix
iou_damage = damage_intersection_pix / damage_union_pix
iou_healthy = healthy_intersection_pix / healthy_union_pix
print(f'healthy_field_ground_truth = {healthy_field_ground_truth_pix / total_ground_truth_pix * 100:.2f} %')
print(f'damage_ground_truth = {damage_ground_truth_pix / total_ground_truth_pix * 100:.2f} %')
print(f'healthy_field_predicted = {healthy_field_predicted_pix / total_predicted_pix * 100:.2f} %')
print(f'damage_field_predicted = {damage_field_predicted_pix / total_predicted_pix * 100:.2f} %')
print(f'damage_prediction_true_positives/damage_field_predicted = {damage_prediction_true_positives_pix / damage_field_predicted_pix * 100:.2f} %')
print(f'iou_damage = {iou_damage:.3f}')
print(f'iou_healthy = {iou_healthy:.3f}')
Batch 0 / 51 Batch 1 / 51 Batch 2 / 51 Batch 3 / 51 Batch 4 / 51 Batch 5 / 51 Batch 6 / 51 Batch 7 / 51 Batch 8 / 51 Batch 9 / 51 Batch 10 / 51 Batch 11 / 51 Batch 12 / 51 Batch 13 / 51 Batch 14 / 51 Batch 15 / 51 Batch 16 / 51 Batch 17 / 51 Batch 18 / 51 Batch 19 / 51 Batch 20 / 51 Batch 21 / 51 Batch 22 / 51 Batch 23 / 51 Batch 24 / 51 Batch 25 / 51 Batch 26 / 51 Batch 27 / 51 Batch 28 / 51 Batch 29 / 51 Batch 30 / 51 Batch 31 / 51 Batch 32 / 51 Batch 33 / 51 Batch 34 / 51 Batch 35 / 51 Batch 36 / 51 Batch 37 / 51 Batch 38 / 51 Batch 39 / 51 Batch 40 / 51 Batch 41 / 51 Batch 42 / 51 Batch 43 / 51 Batch 44 / 51 Batch 45 / 51 Batch 46 / 51 Batch 47 / 51 Batch 48 / 51 Batch 49 / 51 Batch 50 / 51 healthy_field_ground_truth = 91.40 % damage_ground_truth = 8.60 % healthy_field_predicted = 92.07 % damage_field_predicted = 7.93 % damage_prediction_true_positives/damage_field_predicted = 77.80 % iou_damage = 0.596 iou_healthy = 0.955